{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "OHoSU6uI-xIt"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import gym\n",
    "import numpy as np\n",
    "import collections\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import rl_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "TsptTKz6-xIv"
   },
   "outputs": [],
   "source": [
    "class ReplayBuffer:\n",
    "    ''' 经验回放池 '''\n",
    "    def __init__(self, capacity):\n",
    "        self.buffer = collections.deque(maxlen=capacity)  # 队列,先进先出\n",
    "\n",
    "    def add(self, state, action, reward, next_state, done):  # 将数据加入buffer\n",
    "        self.buffer.append((state, action, reward, next_state, done))\n",
    "\n",
    "    def sample(self, batch_size):  # 从buffer中采样数据,数量为batch_size\n",
    "        transitions = random.sample(self.buffer, batch_size)\n",
    "        state, action, reward, next_state, done = zip(*transitions)\n",
    "        return np.array(state), action, reward, np.array(next_state), done\n",
    "\n",
    "    def size(self):  # 目前buffer中数据的数量\n",
    "        return len(self.buffer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "_qO6koMR-xIw"
   },
   "outputs": [],
   "source": [
    "class Qnet(torch.nn.Module):\n",
    "    ''' 只有一层隐藏层的Q网络 '''\n",
    "    def __init__(self, state_dim, hidden_dim, action_dim):\n",
    "        super(Qnet, self).__init__()\n",
    "        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)\n",
    "        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.fc1(x))  # 隐藏层使用ReLU激活函数\n",
    "        return self.fc2(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "wxZItCX4-xIw"
   },
   "outputs": [],
   "source": [
    "class DQN:\n",
    "    ''' DQN算法 '''\n",
    "    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma,\n",
    "                 epsilon, target_update, device):\n",
    "        self.action_dim = action_dim\n",
    "        self.q_net = Qnet(state_dim, hidden_dim,\n",
    "                          self.action_dim).to(device)  # Q网络\n",
    "        # 目标网络\n",
    "        self.target_q_net = Qnet(state_dim, hidden_dim,\n",
    "                                 self.action_dim).to(device)\n",
    "        # 使用Adam优化器\n",
    "        self.optimizer = torch.optim.Adam(self.q_net.parameters(),\n",
    "                                          lr=learning_rate)\n",
    "        self.gamma = gamma  # 折扣因子\n",
    "        self.epsilon = epsilon  # epsilon-贪婪策略\n",
    "        self.target_update = target_update  # 目标网络更新频率\n",
    "        self.count = 0  # 计数器,记录更新次数\n",
    "        self.device = device\n",
    "\n",
    "    def take_action(self, state):  # epsilon-贪婪策略采取动作\n",
    "        if np.random.random() < self.epsilon:\n",
    "            action = np.random.randint(self.action_dim)\n",
    "        else:\n",
    "            state = torch.tensor([state], dtype=torch.float).to(self.device)\n",
    "            action = self.q_net(state).argmax().item()\n",
    "        return action\n",
    "\n",
    "    def update(self, transition_dict):\n",
    "        states = torch.tensor(transition_dict['states'],\n",
    "                              dtype=torch.float).to(self.device)\n",
    "        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(\n",
    "            self.device)\n",
    "        rewards = torch.tensor(transition_dict['rewards'],\n",
    "                               dtype=torch.float).view(-1, 1).to(self.device)\n",
    "        next_states = torch.tensor(transition_dict['next_states'],\n",
    "                                   dtype=torch.float).to(self.device)\n",
    "        dones = torch.tensor(transition_dict['dones'],\n",
    "                             dtype=torch.float).view(-1, 1).to(self.device)\n",
    "\n",
    "        q_values = self.q_net(states).gather(1, actions)  # Q值\n",
    "        # 下个状态的最大Q值\n",
    "        max_next_q_values = self.target_q_net(next_states).max(1)[0].view(\n",
    "            -1, 1)\n",
    "        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones\n",
    "                                                                )  # TD误差目标\n",
    "        dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))  # 均方误差损失函数\n",
    "        self.optimizer.zero_grad()  # PyTorch中默认梯度会累积,这里需要显式将梯度置为0\n",
    "        dqn_loss.backward()  # 反向传播更新参数\n",
    "        self.optimizer.step()\n",
    "\n",
    "        if self.count % self.target_update == 0:\n",
    "            self.target_q_net.load_state_dict(\n",
    "                self.q_net.state_dict())  # 更新目标网络\n",
    "        self.count += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    env = gym.make(\"CartPole-v1\", new_step_api=True, render_mode=\"human\")\n",
    "    episode_return = 0\n",
    "    state = env.reset()\n",
    "    done = False\n",
    "    while not done:\n",
    "        env.render()\n",
    "        action = agent.take_action(state)\n",
    "        next_state, reward, _, done, info = env.step(action)\n",
    "        state = next_state\n",
    "        episode_return += reward\n",
    "    print(episode_return)\n",
    "    env.close()\n",
    "    return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 97226,
     "status": "ok",
     "timestamp": 1649955480772,
     "user": {
      "displayName": "Sam Lu",
      "userId": "15789059763790170725"
     },
     "user_tz": -480
    },
    "id": "acJ1letz-xIx",
    "outputId": "26487c0d-c504-44d6-eb15-5fb137b9488f"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration 0:   0%|                                                                                                                                         | 0/50 [00:00<?, ?it/s]C:\\Users\\21168\\AppData\\Local\\Temp\\ipykernel_26488\\725080296.py:24: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ..\\torch\\csrc\\utils\\tensor_new.cpp:264.)\n",
      "  state = torch.tensor([state], dtype=torch.float).to(self.device)\n",
      "C:\\Users\\21168\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\gym\\utils\\passive_env_checker.py:241: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)\n",
      "  if not isinstance(terminated, (bool, np.bool8)):\n",
      "C:\\Users\\21168\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\gym\\envs\\classic_control\\cartpole.py:179: UserWarning: \u001b[33mWARN: You are calling 'step()' even though this environment has already returned terminated = True. You should always call 'reset()' once you receive 'terminated = True' -- any further steps are undefined behavior.\u001b[0m\n",
      "  logger.warn(\n",
      "Iteration 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:01<00:00,  1.22s/it, episode=50, return=11.100]\n",
      "Iteration 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:07<00:00,  1.35s/it, episode=100, return=11.000]\n",
      "Iteration 2: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:10<00:00,  1.42s/it, episode=150, return=11.200]\n",
      "Iteration 3:   2%|██▌                                                                                                                              | 1/50 [00:01<01:07,  1.38s/it]"
     ]
    }
   ],
   "source": [
    "lr = 2e-3\n",
    "num_episodes = 500\n",
    "hidden_dim = 128\n",
    "gamma = 0.98\n",
    "epsilon = 0.01\n",
    "target_update = 10\n",
    "buffer_size = 10000\n",
    "minimal_size = 500\n",
    "batch_size = 64\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "env_name = 'CartPole-v1'\n",
    "env = gym.make(env_name, new_step_api=True)\n",
    "random.seed(0)\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "replay_buffer = ReplayBuffer(buffer_size)\n",
    "state_dim = env.observation_space.shape[0]\n",
    "action_dim = env.action_space.n\n",
    "agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,\n",
    "            target_update, device)\n",
    "\n",
    "return_list = []\n",
    "for i in range(10):\n",
    "    with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:\n",
    "        for i_episode in range(int(num_episodes / 10)):\n",
    "            episode_return = 0\n",
    "            state = env.reset(seed=0)\n",
    "            done = False\n",
    "            while not done:\n",
    "                action = agent.take_action(state)\n",
    "                next_state, reward, _, done, info = env.step(action)\n",
    "                replay_buffer.add(state, action, reward, next_state, done)\n",
    "                state = next_state\n",
    "                episode_return += reward\n",
    "                # 当buffer数据的数量超过一定值后,才进行Q网络训练\n",
    "                if replay_buffer.size() > minimal_size:\n",
    "                    b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)\n",
    "                    transition_dict = {\n",
    "                        'states': b_s,\n",
    "                        'actions': b_a,\n",
    "                        'next_states': b_ns,\n",
    "                        'rewards': b_r,\n",
    "                        'dones': b_d\n",
    "                    }\n",
    "                    agent.update(transition_dict)\n",
    "            return_list.append(episode_return)\n",
    "            if (i_episode + 1) % 10 == 0:\n",
    "                pbar.set_postfix({\n",
    "                    'episode':\n",
    "                    '%d' % (num_episodes / 10 * i + i_episode + 1),\n",
    "                    'return':\n",
    "                    '%.3f' % np.mean(return_list[-10:])\n",
    "                })\n",
    "            pbar.update(1)\n",
    "\n",
    "# Iteration 0: 100%|██████████| 50/50 [00:00<00:00, 764.86it/s, episode=50,\n",
    "# return=9.300]\n",
    "# Iteration 1: 100%|██████████| 50/50 [00:04<00:00, 10.66it/s, episode=100,\n",
    "# return=12.300]\n",
    "# Iteration 2: 100%|██████████| 50/50 [00:24<00:00,  2.05it/s, episode=150,\n",
    "# return=123.000]\n",
    "# Iteration 3: 100%|██████████| 50/50 [01:25<00:00,  1.71s/it, episode=200,\n",
    "# return=153.600]\n",
    "# Iteration 4: 100%|██████████| 50/50 [01:30<00:00,  1.80s/it, episode=250,\n",
    "# return=180.500]\n",
    "# Iteration 5: 100%|██████████| 50/50 [01:24<00:00,  1.68s/it, episode=300,\n",
    "# return=185.000]\n",
    "# Iteration 6: 100%|██████████| 50/50 [01:32<00:00,  1.85s/it, episode=350,\n",
    "# return=193.900]\n",
    "# Iteration 7: 100%|██████████| 50/50 [01:31<00:00,  1.84s/it, episode=400,\n",
    "# return=196.600]\n",
    "# Iteration 8: 100%|██████████| 50/50 [01:33<00:00,  1.88s/it, episode=450,\n",
    "# return=193.800]\n",
    "# Iteration 9: 100%|██████████| 50/50 [01:34<00:00,  1.88s/it, episode=500,\n",
    "# return=200.000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 573
    },
    "executionInfo": {
     "elapsed": 698,
     "status": "ok",
     "timestamp": 1649955495697,
     "user": {
      "displayName": "Sam Lu",
      "userId": "15789059763790170725"
     },
     "user_tz": -480
    },
    "id": "AFiCxG4W-xIy",
    "outputId": "b5610f6f-8df9-4156-ecb9-e8cb2901133c"
   },
   "outputs": [],
   "source": [
    "episodes_list = list(range(len(return_list)))\n",
    "plt.plot(episodes_list, return_list)\n",
    "plt.xlabel('Episodes')\n",
    "plt.ylabel('Returns')\n",
    "plt.title('DQN on {}'.format(env_name))\n",
    "plt.show()\n",
    "\n",
    "mv_return = rl_utils.moving_average(return_list, 9)\n",
    "plt.plot(episodes_list, mv_return)\n",
    "plt.xlabel('Episodes')\n",
    "plt.ylabel('Returns')\n",
    "plt.title('DQN on {}'.format(env_name))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CDk1DgrL-xIz"
   },
   "outputs": [],
   "source": [
    "class ConvolutionalQnet(torch.nn.Module):\n",
    "    ''' 加入卷积层的Q网络 '''\n",
    "    def __init__(self, action_dim, in_channels=4):\n",
    "        super(ConvolutionalQnet, self).__init__()\n",
    "        self.conv1 = torch.nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)\n",
    "        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=4, stride=2)\n",
    "        self.conv3 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1)\n",
    "        self.fc4 = torch.nn.Linear(7 * 7 * 64, 512)\n",
    "        self.head = torch.nn.Linear(512, action_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x / 255\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.relu(self.conv3(x))\n",
    "        x = F.relu(self.fc4(x))\n",
    "        return self.head(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
